import pandas as pd
import sqlalchemy

from sqlalchemy import or_, and_, text
from sqlalchemy.exc import IntegrityError

from sqltool.models import (
    Session,
    User,
    Address,
    Book,
    Author,
    Department,
    Employee,
    Code,
)


def or_filters(*args):
    """Combine SQLAlchemy filter expressions with OR.

    Example: ``or_filters(User.age < 30, User.age > 60)`` yields
    ``age < 30 OR age > 60``.
    """
    clauses = list(args)
    return or_(*clauses)


def and_filters(*args):
    """Combine SQLAlchemy filter expressions with AND.

    Can be nested with ``or_filters``, e.g.
    ``and_filters(or_filters(a, b), c)`` yields ``(a OR b) AND c``.
    """
    clauses = list(args)
    return and_(*clauses)


class SqlHelper:
    """Convenience wrapper around one SQLAlchemy ``Session``.

    Single-object writers (``add``, ``delete``, ``update``, ``upsert``) commit
    immediately and re-raise on failure. Batch writers (``*_all``) commit the
    whole batch as one transaction and return ``False`` instead of raising on
    an ``IntegrityError`` (e.g. a unique-key conflict). Every failed write is
    rolled back so the session stays usable. Query helpers return ``Query``
    objects or plain values and never commit.
    """

    def __init__(self):
        self.session = Session()

    # -- transaction plumbing ------------------------------------------------

    def _apply(self, op, obj):
        """Run ``op(obj)`` and commit; roll back and re-raise on any error.

        Without the rollback, a failed flush leaves the session refusing
        further work until ``rollback()`` is called explicitly.
        """
        try:
            op(obj)
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise

    def _apply_all(self, op, models) -> bool:
        """Run ``op`` on each of ``models`` and commit them as one transaction.

        Returns True on success. On ``IntegrityError`` the whole batch is
        rolled back and False is returned; any other error is rolled back and
        re-raised.
        """
        # No explicit session.begin(): the session starts a transaction on
        # first use, and begin() raises if one is already in progress (e.g.
        # after any earlier query on this session).
        try:
            for m in models:
                op(m)
            self.session.commit()
        except IntegrityError:
            self.session.rollback()
            return False
        except Exception:
            self.session.rollback()
            raise
        return True

    # -- single-object writes ------------------------------------------------

    def add(self, obj):
        """Insert ``obj`` and commit."""
        self._apply(self.session.add, obj)

    def delete(self, obj):
        """Delete ``obj`` and commit."""
        self._apply(self.session.delete, obj)

    def update(self, obj):
        """Merge ``obj`` (matched by primary key) into the session and commit."""
        self._apply(self.session.merge, obj)

    def upsert(self, model):
        """Update the row matching ``model``'s unique keys, else insert; commit."""
        self._apply(self._upsert, model)

    def _upsert(self, model):
        """Insert ``model`` or update the existing row sharing its unique keys.

        Looks up a row whose unique-key columns equal ``model``'s values. If
        one exists, its primary key is copied onto ``model`` (filling in the
        auto-increment id) and ``model`` is merged; otherwise ``model`` is
        added. Does not commit.

        NOTE(review): all unique-key columns (from every unique constraint)
        are AND-ed together, and column names are assumed to equal the mapped
        attribute names; only the first primary-key column is copied.
        """
        unique = self.get_unique_keys(model)
        if not unique:
            # An unfiltered lookup would match an arbitrary row and overwrite
            # it, so without unique keys this can only be an insert.
            self.session.add(model)
            return
        cls = model.__class__
        existing = (
            self.session.query(cls)
            .filter(*[getattr(cls, c) == getattr(model, c) for c in unique])
            .first()
        )
        if existing is None:
            self.session.add(model)
            return
        primary_key = self.get_primary_key(model)
        setattr(model, primary_key, getattr(existing, primary_key))
        self.session.merge(model)

    # -- batch writes (one transaction each; False on IntegrityError) ---------

    def add_all(self, models) -> bool:
        """Insert all ``models``; roll back and return False on a conflict."""
        return self._apply_all(self.session.add, models)

    def update_all(self, models) -> bool:
        """Merge all ``models``; roll back and return False on a conflict."""
        return self._apply_all(self.session.merge, models)

    def delete_all(self, models) -> bool:
        """Delete all ``models``; roll back and return False on a conflict."""
        return self._apply_all(self.session.delete, models)

    def upsert_all(self, models) -> bool:
        """Upsert all ``models``; roll back and return False on a conflict."""
        return self._apply_all(self._upsert, models)

    # -- queries ---------------------------------------------------------------

    def query(self, model):
        """Return an unfiltered ``Query`` for ``model``."""
        return self.session.query(model)

    def update_by(self, model, filters, col_value_map):
        """Bulk-update rows of ``model`` matching ``filters``.

        Returns the matched row count. Does NOT commit; call ``commit()``.
        """
        return self.session.query(model).filter(filters).update(col_value_map)

    def query_by(self, model, *args):
        """Return a ``Query`` for ``model`` with all ``args`` AND-ed together."""
        return self.session.query(model).filter(and_(*args))

    def query_or(self, model, *args):
        """Return a ``Query`` for ``model`` with all ``args`` OR-ed together."""
        return self.session.query(model).filter(or_(*args))

    def query_by_func(self, model, func):
        """Return a ``Query`` for ``model`` filtered by a composed condition.

        Build ``func`` with ``or_filters`` / ``and_filters``, e.g.
            and_filters(or_filters(User.age < 30, User.age > 60), User.name == 'lisi_6')
        which is equivalent to
            (User.age < 30 or User.age > 60) and User.name == 'lisi_6'
        """
        return self.session.query(model).filter(func)

    def get_by(self, model, **kwargs):
        """Return the first ``model`` row matching ``kwargs``, or None."""
        return self.session.query(model).filter_by(**kwargs).first()

    def count_by(self, model, **kwargs):
        """Return the number of ``model`` rows matching ``kwargs``."""
        return self.session.query(model).filter_by(**kwargs).count()

    def exists(self, model, **kwargs) -> bool:
        """Return True if at least one ``model`` row matches ``kwargs``."""
        # Query.exists() only builds an EXISTS clause; it has to be selected
        # and executed to obtain an actual boolean.
        clause = self.session.query(model).filter_by(**kwargs).exists()
        return bool(self.session.query(clause).scalar())

    def join(self, model1, model2):
        """Return a ``Query`` for ``model1`` joined to ``model2``."""
        return self.session.query(model1).join(model2)

    def join_ex(self, model1, model2):
        """Return a ``Query`` yielding ``(model1, model2)`` tuples.

        No join condition is added; callers must ``.filter()`` on the
        relating columns or they get a cross product.
        """
        return self.session.query(model1, model2)

    def first(self, query):
        """Return the first result of ``query``, or None."""
        return query.first()

    def commit(self):
        """Commit the current transaction."""
        self.session.commit()

    def close(self):
        """Commit pending work, then close the session (even if commit fails)."""
        try:
            self.session.commit()
        finally:
            self.session.close()

    # -- model introspection ---------------------------------------------------

    def get_primary_key(self, model):
        """Return the name of the first primary-key column of ``model``'s table."""
        for col in model.__table__.columns:
            if col.primary_key:
                return col.name

    def get_unique_keys(self, model) -> list[str]:
        """Return the unique-key column names of ``model``'s table.

        Collects columns declared ``unique=True`` plus the columns of every
        unique ``Index`` (e.g. from ``__table_args__``), de-duplicated in
        first-seen order.
        """
        names = [c.name for c in model.__table__.columns if c.unique]
        for idx in model.__table__.indexes:
            if idx.unique:
                names.extend(c.name for c in idx.columns)
        # A column may be both unique=True and part of a unique index.
        return list(dict.fromkeys(names))

    # -- raw SQL / pandas --------------------------------------------------------

    def execute_sql(self, sqlstr) -> list[tuple]:
        """Execute a raw SQL string in the session and return all rows.

        The string is executed verbatim; never build it from untrusted input.
        """
        return self.session.execute(text(sqlstr)).fetchall()

    def read_from(self, sql) -> pd.DataFrame:
        """Run a raw SQL string and return the result as a DataFrame.

        The string is wrapped in ``text()``. pandas is handed a Connection
        (not the Engine), which is closed once the frame is built.
        """
        with self.session.bind.connect() as conn:
            return pd.read_sql_query(text(sql), conn)

    def read_model(self, model) -> pd.DataFrame:
        """Load every row of ``model`` into a DataFrame.

        pandas is handed a Connection (not the Engine), which is closed once
        the frame is built.
        """
        statement = self.session.query(model).statement
        with self.session.bind.connect() as conn:
            return pd.read_sql(statement, conn)


def demo() -> None:
    """Walk through the SqlHelper API against the sqltool models.

    Inserts sample Address/User/Author/Book/Department/Employee/Code rows,
    runs the query, bulk-update, upsert and raw-SQL/pandas helpers, and
    prints the results.
    NOTE(review): assumes the database behind ``Session`` is reachable and
    its tables already exist — confirm before running.
    """
    sql_helper = SqlHelper()
    # Existing address count, used as a suffix for the new address/user names.
    count = sql_helper.count_by(Address)
    addr0 = Address()
    addr0.address = f"浙江杭州_{count + 1}"
    sql_helper.add(addr0)
    # addr = sql_helper.query(Address).first()
    user = User()
    # add() has already committed addr0, so its database-assigned id can be read.
    user.addr_id = addr0.id
    user.age = 33
    user.name = f"lisi_{count + 1}"
    sql_helper.add(user)

    # Assorted query helpers; these results are not used further.
    users = sql_helper.query_by(User, User.name == "Tom").all()
    users = sql_helper.query(User).filter(User.age > 30).all()
    users = sql_helper.query(User).filter(User.name.like("a%")).all()
    tom = sql_helper.get_by(User, name="Tom")
    count = sql_helper.count_by(User)
    exists = sql_helper.exists(User, age=30)
    first_user = sql_helper.first(sql_helper.query(User))
    # Pair each user with its address by filtering on the FK; yields (Address, User) tuples.
    query_result = (
        sql_helper.join_ex(Address, User).filter(User.addr_id == Address.id).all()
    )
    for address, user in query_result:
        print(user.id, user.name, address.id, address.address)
    # query_result = sql_helper.query(User).join(Address).filter(User.addr_id == Address.id).all()
    # Explicit join; presumably relies on the FK declared on the models — confirm in sqltool.models.
    query_result = sql_helper.query(User).join(Address).all()
    for result in query_result:
        print(result)

    # Joins: create two authors and their books
    author1 = Author(name="Jam")
    author2 = Author(name="Sam")
    sql_helper.add(author1)
    sql_helper.add(author2)
    # author ids are readable here because add() committed the authors.
    book1 = Book(title="C Language", author_id=author1.id)
    book2 = Book(title="C++ Language", author_id=author1.id)
    book3 = Book(title="Python", author_id=author2.id)
    book4 = Book(title="JavaScript", author_id=author2.id)
    sql_helper.add(book1)
    sql_helper.add(book2)
    sql_helper.add(book3)
    sql_helper.add(book4)

    # Many-to-one (n--1): each book reaches its author via book.author
    books = sql_helper.query(Book).join(Author).all()
    for book in books:
        print(f"Book: {book.title} Author:{book.author.name}")

    # The reverse direction: each author lists its books via author.books
    authors = sql_helper.query(Author).join(Book).all()
    for author in authors:
        for book in author.books:
            print(f"Author: {author.name} Book: {book.title}")

    # One-to-many (1--n): one department with several employees
    department = Department(name="Sales")
    sql_helper.add(department)
    employee1 = Employee(name="John", department=department, remark="test1111")
    employee2 = Employee(name="Jane", department=department, remark="test2222")
    sql_helper.add(employee1)
    sql_helper.add(employee2)
    dept = sql_helper.query(Department).filter_by(name="Sales").first()
    employees = dept.employees
    for employee in employees:
        print(employee.name)

    # Bulk update with a single condition (update_by does not commit, hence commit())
    rst = sql_helper.update_by(
        Employee, Employee.name == "John", {Employee.remark: employee1.remark}
    )
    sql_helper.commit()
    # Bulk update with several conditions combined by OR
    rst = sql_helper.update_by(
        Employee,
        or_filters(Employee.name == "John", Employee.name == "Jane"),
        {Employee.remark: employee2.remark},
    )
    sql_helper.commit()
    print("------------------------------------------------------------")
    # Upsert on a composite unique index: update the row if it exists, else insert.
    # Both codes share symbol/exchange, so the second upsert is meant to update
    # the first row (presumably (symbol, exchange) is Code's unique index — confirm).
    code1 = Code(symbol="SA309", exchange="CZCE", remark="test11111")
    code2 = Code(symbol="SA309", exchange="CZCE", remark="test22222")
    # sql_helper.add(code1)
    sql_helper.upsert(code1)
    sql_helper.upsert(code2)
    print("------------------------------------------------------------")
    # OR query
    users = sql_helper.query_or(
        User, User.age > 55, User.age < 30, User.age == 44, User.name == "lisi_40"
    ).all()
    for user in users:
        print(user)
    print("------------------------------------------------------------")
    # AND query (the Query is iterated directly, without .all())
    users = sql_helper.query_by(User, User.name == "lisi_6", User.age > 45)
    for user in users:
        print(user)
    print("------------------------------------------------------------")
    # Nested OR + AND: (age < 30 OR age > 60) AND name == 'lisi_6'
    users = sql_helper.query_by_func(
        User,
        and_filters(or_filters(User.age < 30, User.age > 60), User.name == "lisi_6"),
    ).all()
    for user in users:
        print(user)
    print("------------------------------------------------------------")
    # Raw SQL: returns the fetched rows
    result = sql_helper.execute_sql("select * from users;")
    for item in result:
        print(item)
    print("------------------------------------------------------------")
    # Whole User table as a pandas DataFrame
    df = sql_helper.read_model(User)
    print(df.head())
    print("------------------------------------------------------------")
    # Raw SQL as a pandas DataFrame
    df = sql_helper.read_from("select * from users;")
    print(df.tail())
    print("------------------------------------------------------------")
    result = sql_helper.execute_sql("select * from employee;")
    for item in result:
        print(item)
    print("------------------------------------------------------------")
    # close() commits pending work and closes the session.
    sql_helper.close()


if __name__ == "__main__":
    # Report the installed SQLAlchemy version, then run the walkthrough.
    installed_version = sqlalchemy.__version__
    print(installed_version)
    demo()
